import numpy as np
from amd._nearest_neighbours import nearest_neighbours_data

from gymnasium.spaces import Box

from collections import deque
import numpy as np
from itertools import chain

class Normalizer:
    """Running mean / standard deviation over a sliding window of numbers.

    The last ``max_size`` values are kept in a deque together with their
    running sum and running sum of squares, so querying the statistics does
    not require re-scanning the window.
    """

    def __init__(self, max_size):
        """Create an empty window.

        Args:
            max_size: Maximum number of values retained (passed to
                ``deque(maxlen=...)``); ``None`` means the window is unbounded.
        """
        self.queue = deque(maxlen=max_size)
        self.sum = 0
        self.sum_of_squares = 0
        self.count = 0

    def add_numbers(self, numbers):
        """Append ``numbers`` to the window, evicting the oldest values when full.

        Args:
            numbers: Iterable of values to add, in order.
        """
        window = self.queue.maxlen
        if window == 0:
            # A zero-length window retains nothing; popping from it would raise.
            return

        for num in numbers:
            if window is not None and len(self.queue) == window:
                # Window is full: drop the oldest value from the running sums.
                oldest = self.queue.popleft()
                self.sum -= oldest
                self.sum_of_squares -= oldest ** 2
                self.count -= 1

            self.queue.append(num)
            self.sum += num
            self.sum_of_squares += num ** 2
            self.count += 1

    def get_mean_std(self):
        """Return ``(mean, std)`` of the values currently in the window.

        The std is the population standard deviation,
        ``sqrt(E[x^2] - E[x]^2)``. Returns ``(0, 0)`` when the window is empty.
        """
        if self.count == 0:
            return 0, 0

        mean = self.sum / self.count
        variance = self.sum_of_squares / self.count - mean ** 2
        # Rounding error in the running sums (and the E[x^2] - E[x]^2 form)
        # can push the variance slightly below zero; clamp so sqrt never
        # produces NaN.
        return mean, np.sqrt(np.maximum(variance, 0.0))


class CombinedFlatObserver:
    """Build one flat observation vector per agent from a periodic atomic world.

    Each agent's observation is the concatenation of:

    * the agent's own node features,
    * for each of its ``world.neighbors_limit`` nearest neighbours: the
      distance, the neighbour's node features, and the 3-D Cartesian vector
      from the agent to that neighbour's periodic image,
    * the world features, optionally followed by extra environment features.
    """

    def __init__(self, world, agents, **kwargs):
        """Store the world and the agent mapping.

        Args:
            world: World object. This class reads ``agents``, ``atoms.positions``,
                ``cell``, ``neighbors_limit``, ``names_by_indices`` and
                ``features()`` from it.
            agents: Mapping from external agent key to the agent name used in
                ``world.agents``.
            **kwargs: Accepted for interface compatibility; unused.
        """
        self.world = world
        self.last_observation = {}

        self.agents = agents

    def observation_space(self, env_observation_len=0):
        """Return the gymnasium ``Box`` describing one flat observation.

        Args:
            env_observation_len: Number of extra environment features that
                ``update_observations`` will append via ``env_obs``.

        Returns:
            An unbounded float32 ``Box`` whose length matches the vectors
            produced by ``update_observations``.
        """
        node_f_length = next(iter(self.world.agents.values())).features().shape[0]
        # Per-neighbour edge features: distance (1) + displacement vector (3).
        edge_f_length = 4

        flat_obs_len = len(self.world.features()) + env_observation_len  # world/env features
        flat_obs_len += (self.world.neighbors_limit + 1) * node_f_length  # own + neighbour node features
        flat_obs_len += self.world.neighbors_limit * edge_f_length  # edge features

        return Box(low=-np.inf, high=np.inf, shape=(flat_obs_len,), dtype=np.float32)

    def update_neighbours(self):
        """Find the ``neighbors_limit`` nearest periodic neighbours of every atom.

        Returns:
            ``(edges, j_cell, dists)`` where ``edges`` is ``(N, K)`` neighbour
            atom indices in the base cell, ``j_cell`` is ``(N, K, 3)`` lattice
            offsets of each neighbour image (coefficients of the rows of
            ``world.cell``), and ``dists`` is ``(N, K)`` neighbour distances.
        """

        def offset(image_positions, base_positions, lattice_vecs):
            # Solve lattice_vecs.T @ x = d, i.e. express each image displacement
            # as a combination of the lattice vectors (rows of lattice_vecs).
            A = lattice_vecs.transpose()
            b = image_positions - base_positions
            return np.linalg.solve(A, b.transpose()).transpose()

        atoms_positions = np.array(self.world.atoms.positions)
        cell_vecs = np.array(self.world.cell)
        k = self.world.neighbors_limit

        dists, cloud, inds = nearest_neighbours_data(motif=atoms_positions,
                                                     cell=cell_vecs,
                                                     x=atoms_positions,
                                                     k=k + 1)

        num_atoms = len(atoms_positions)
        # Drop the first neighbour of each atom: assumed to be the atom itself
        # (distance 0) since x == motif.
        cloud_idx = inds[:, 1:].flatten()
        # NOTE(review): assumes `cloud` is laid out as consecutive copies of the
        # motif, so idx % num_atoms recovers the base-cell atom index — confirm
        # against amd.
        j_idx = cloud_idx % num_atoms
        j_carts = cloud[cloud_idx]
        base_j_carts = atoms_positions[j_idx]
        j_cell = offset(j_carts, base_j_carts, cell_vecs).reshape(num_atoms, k, 3)

        edges = j_idx.reshape(num_atoms, k)
        dists = dists[:, 1:]

        return edges, j_cell, dists

    def update_observations(self, env_obs=None):
        """Recompute and cache the flat observation of every agent.

        Args:
            env_obs: Optional extra environment features appended after the
                world features in every observation. Any iterable (including a
                NumPy array) is accepted; ``None`` means no extra features.
        """
        world = self.world

        agent_features = {
            name: list(world.agents[name].features())
            for name in self.agents.values()
        }

        # Copy so that appending env_obs never mutates a list owned by the world.
        env_world_features = list(world.features())
        if env_obs is not None:
            env_world_features.extend(env_obs)

        positions = np.asarray(world.atoms.positions)
        cell = np.array(world.cell)

        neighbors, offsets, dists = self.update_neighbours()

        # Vector from each atom i to the base-cell copy of each neighbour j.
        displacement = positions[neighbors] - positions[:, np.newaxis, :]  # (N, K, 3)
        # `offsets` are coefficients of the rows of `cell` (update_neighbours
        # solves cell.T @ x = d), so the Cartesian shift is offsets @ cell.
        periodic_shift = np.einsum('nkj,ji->nki', offsets, cell)  # (N, K, 3)
        vectors_to = displacement + periodic_shift  # (N, K, 3)

        # NOTE(review): assumes the iteration order of self.agents matches the
        # atom order (row a_idx of the neighbour arrays), and that every
        # neighbour atom belongs to one of self.agents — confirm with callers.
        for a_idx, (agent, flat_agent) in enumerate(self.agents.items()):
            parts = [agent_features[flat_agent]]

            for j_idx, dist, vec in zip(neighbors[a_idx], dists[a_idx], vectors_to[a_idx]):
                j_agent_name = world.names_by_indices[j_idx]
                parts.append([dist])
                parts.append(agent_features[j_agent_name])
                parts.append(vec)

            parts.append(env_world_features)

            self.last_observation[agent] = np.asarray(list(chain.from_iterable(parts)))

    def observe(self, agent, as_array=True):  # TODO unify interface for observers
        """Return the cached observation of ``agent``.

        Args:
            agent: Agent key as used in ``self.agents``.
            as_array: When true, return the observation converted to float32.
        """
        if as_array:
            return np.float32(np.asarray(self.last_observation[agent]))
        return self.last_observation[agent]
